{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train\n",
    "# !wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse(file):\n",
    "    \"\"\"Read a whitespace-separated CoNLL file into two parallel column lists.\n",
    "\n",
    "    Args:\n",
    "        file: path of the file to read.\n",
    "\n",
    "    Returns:\n",
    "        (left, right): the first and second field of every kept line, in\n",
    "        file order. Lines containing '-DOCSTART-' and lines with fewer than\n",
    "        two fields (blank, whitespace-only, single token) are skipped.\n",
    "    \"\"\"\n",
    "    left, right = [], []\n",
    "    with open(file) as fopen:\n",
    "        for text in fopen:\n",
    "            if '-DOCSTART-' in text:\n",
    "                continue\n",
    "            splitted = text.split()\n",
    "            # A whitespace-only or one-field line has no second column;\n",
    "            # indexing it would raise IndexError.\n",
    "            if len(splitted) < 2:\n",
    "                continue\n",
    "            left.append(splitted[0])\n",
    "            right.append(splitted[1])\n",
    "    return left, right"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the first two columns of each split (column 0 -> left, column 1 -> right).\n",
    "(left_train, right_train), (left_test, right_test) = (\n",
    "    parse(path) for path in ('eng.train', 'eng.testa')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_string(string):\n",
    "    \"\"\"Strip unsupported characters and title-case all-caps words.\n",
    "\n",
    "    Each run of characters other than ASCII letters, digits, '-', '/' and\n",
    "    space is replaced by one space; the resulting whitespace-separated\n",
    "    words go through `to_title` and are re-joined with single spaces.\n",
    "    \"\"\"\n",
    "    # Raw string so the regex escapes are not parsed as (invalid) str escapes.\n",
    "    words = re.sub(r'[^A-Za-z0-9\\-\\/ ]+', ' ', string).split()\n",
    "    return ' '.join(to_title(word) for word in words)\n",
    "\n",
    "def to_title(string):\n",
    "    \"\"\"Return `string` title-cased if it is all upper-case, else unchanged.\"\"\"\n",
    "    return string.title() if string.isupper() else string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array(['\"', '$', \"''\", '(', ')', ',', '.', ':', 'CC', 'CD', 'DT', 'EX',\n",
       "        'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNP', 'NNPS',\n",
       "        'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS',\n",
       "        'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',\n",
       "        'WDT', 'WP', 'WP$', 'WRB'], dtype='<U6'),\n",
       " array([ 2178,   427,    35,  2866,  2866,  7291,  7389,  2386,  3653,\n",
       "        19704, 13453,   136,   166, 19064, 11831,   382,   254,    13,\n",
       "         1199, 23899, 34392,   684,  9903,     4,    33,  1553,  3163,\n",
       "         1520,  3975,   163,    35,   528,   439,  3469,    30,  4252,\n",
       "         8293,  2585,  4105,  1436,  2426,   506,   528,    23,   384]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Distinct values of the second column in the training split, with counts.\n",
    "np.unique(right_train, return_counts = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared vocabularies; `parse_XY` grows them in place on every call.\n",
    "word2idx = dict(PAD = 0, NUM = 1, UNK = 2)\n",
    "tag2idx = dict(PAD = 0)\n",
    "char2idx = dict(PAD = 0)\n",
    "# Next unused id of each vocabulary.\n",
    "word_idx, tag_idx, char_idx = 3, 1, 1\n",
    "\n",
    "def parse_XY(texts, labels):\n",
    "    \"\"\"Encode parallel token / label lists as integer ids.\n",
    "\n",
    "    Tokens are lower-cased first. Characters, labels and words that are not\n",
    "    yet known are appended to the module-level `char2idx`, `tag2idx` and\n",
    "    `word2idx`, so the ids depend on the order of calls.\n",
    "\n",
    "    Returns:\n",
    "        (X, Y): a list of word ids and a numpy array of label ids.\n",
    "    \"\"\"\n",
    "    global word2idx, tag2idx, char2idx, word_idx, tag_idx, char_idx\n",
    "    word_ids, label_ids = [], []\n",
    "    for position, raw in enumerate(texts):\n",
    "        token = raw.lower()\n",
    "        label = labels[position]\n",
    "        for ch in token:\n",
    "            if ch in char2idx:\n",
    "                continue\n",
    "            char2idx[ch] = char_idx\n",
    "            char_idx += 1\n",
    "        if label not in tag2idx:\n",
    "            tag2idx[label] = tag_idx\n",
    "            tag_idx += 1\n",
    "        if token not in word2idx:\n",
    "            word2idx[token] = word_idx\n",
    "            word_idx += 1\n",
    "        label_ids.append(tag2idx[label])\n",
    "        word_ids.append(word2idx[token])\n",
    "    return word_ids, np.array(label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode train first, then test: ids are assigned in first-seen order, and\n",
    "# the test call also adds its unseen words, tags and characters to the maps.\n",
    "train_X, train_Y = parse_XY(left_train, right_train)\n",
    "test_X, test_Y = parse_XY(left_test, right_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reverse lookups: id -> word and id -> tag.\n",
    "idx2word = dict(zip(word2idx.values(), word2idx.keys()))\n",
    "idx2tag = dict(zip(tag2idx.values(), tag2idx.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_len = 50\n",
    "def iter_seq(x):\n",
    "    \"\"\"Return every contiguous window of `seq_len` items of `x` (stride 1).\n",
    "\n",
    "    The result is a numpy array with one window per row. An input shorter\n",
    "    than `seq_len` produces an empty array.\n",
    "    \"\"\"\n",
    "    # `+ 1` keeps the final window; without it the last element of `x`\n",
    "    # would not appear in any window.\n",
    "    return np.array([x[i: i+seq_len] for i in range(len(x) - seq_len + 1)])\n",
    "\n",
    "def to_train_seq(*args):\n",
    "    \"\"\"Apply `iter_seq` to each positional argument and return the results as a list.\"\"\"\n",
    "    return list(map(iter_seq, args))\n",
    "\n",
    "def generate_char_seq(batch):\n",
    "    \"\"\"Build a character-id tensor for a 2-D array of word ids.\n",
    "\n",
    "    The last axis is as long as the longest word in `batch`. Each word is\n",
    "    looked up in `idx2word` and written right-aligned in reverse order (its\n",
    "    first character lands in the last slot); unused slots stay 0.\n",
    "\n",
    "    Returns:\n",
    "        int32 array of shape (batch.shape[0], batch.shape[1], longest word).\n",
    "    \"\"\"\n",
    "    longest = max([len(idx2word[word_id]) for row in batch for word_id in row])\n",
    "    n_rows, n_cols = batch.shape[0], batch.shape[1]\n",
    "    encoded = np.zeros((n_rows, n_cols, longest), dtype=np.int32)\n",
    "    for row in range(n_rows):\n",
    "        for col in range(n_cols):\n",
    "            word = idx2word[batch[row, col]]\n",
    "            for offset, ch in enumerate(word, start=1):\n",
    "                encoded[row, col, -offset] = char2idx[ch]\n",
    "    return encoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(203571, 50)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Window the encoded training stream and build the matching character tensor.\n",
    "X_seq, Y_seq = to_train_seq(train_X, train_Y)\n",
    "X_char_seq = generate_char_seq(X_seq)\n",
    "X_seq.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(51312, 50)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same windowing for the evaluation split.\n",
    "X_seq_test, Y_seq_test = to_train_seq(test_X, test_Y)\n",
    "X_char_seq_test = generate_char_seq(X_seq_test)\n",
    "X_seq_test.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: from here on train_X / train_Y / test_X / test_Y hold the windowed\n",
    "# arrays, not the flat id sequences returned by parse_XY.\n",
    "train_X, train_Y, train_char = X_seq, Y_seq, X_char_seq\n",
    "test_X, test_Y, test_char = X_seq_test, Y_seq_test, X_char_seq_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def layer_norm(inputs, epsilon=1e-8):\n",
    "    \"\"\"Normalise `inputs` over the last axis, then apply a learned scale and shift.\n",
    "\n",
    "    Fetches the variables 'gamma' (initialised to ones) and 'beta'\n",
    "    (initialised to zeros) from the current variable scope; both are sized\n",
    "    like the last axis of `inputs`.\n",
    "    \"\"\"\n",
    "    mu, sigma_sq = tf.nn.moments(inputs, [-1], keep_dims=True)\n",
    "    centred = inputs - mu\n",
    "\n",
    "    feature_shape = inputs.get_shape()[-1:]\n",
    "    gamma = tf.get_variable('gamma', feature_shape, tf.float32, tf.ones_initializer())\n",
    "    beta = tf.get_variable('beta', feature_shape, tf.float32, tf.zeros_initializer())\n",
    "\n",
    "    return gamma * (centred / (tf.sqrt(sigma_sq + epsilon))) + beta\n",
    "\n",
    "def multihead_attn(queries, keys, q_masks, k_masks, future_binding, num_units, num_heads):\n",
    "    \"\"\"Multi-head scaled dot-product attention with residual add and layer norm.\n",
    "\n",
    "    Args:\n",
    "        queries: float tensor indexed here as (batch, T_q, depth).\n",
    "        keys: float tensor indexed here as (batch, T_k, depth).\n",
    "        q_masks: (batch, T_q) mask; attention rows where it is 0 are zeroed.\n",
    "        k_masks: (batch, T_k) mask; keys where it is 0 are excluded from the softmax.\n",
    "        future_binding: if True, position i only attends to keys at positions <= i.\n",
    "        num_units: width of the Q / K / V projections. The residual add needs\n",
    "            it to match the last dimension of `queries`.\n",
    "        num_heads: number of heads; `num_units` is split evenly between them.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with the same shape as `queries`.\n",
    "    \"\"\"\n",
    "    T_q = tf.shape(queries)[1]\n",
    "    T_k = tf.shape(keys)[1]\n",
    "\n",
    "    Q = tf.layers.dense(queries, num_units, name='Q')\n",
    "    K_V = tf.layers.dense(keys, 2*num_units, name='K_V')\n",
    "    K, V = tf.split(K_V, 2, -1)\n",
    "\n",
    "    # Fold the heads into the batch axis.\n",
    "    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)\n",
    "    K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)\n",
    "    V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)\n",
    "\n",
    "    align = tf.matmul(Q_, tf.transpose(K_, [0,2,1]))\n",
    "    align = align / np.sqrt(K_.get_shape().as_list()[-1])\n",
    "\n",
    "    # Large finite negative rather than -inf: a row whose keys are all\n",
    "    # masked then softmaxes to finite weights instead of NaN.\n",
    "    paddings = tf.fill(tf.shape(align), float(-2 ** 32 + 1))\n",
    "\n",
    "    key_masks = tf.tile(k_masks, [num_heads, 1])\n",
    "    key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, T_q, 1])\n",
    "    align = tf.where(tf.equal(key_masks, 0), paddings, align)\n",
    "    if future_binding:\n",
    "        lower_tri = tf.ones([T_q, T_k])\n",
    "        lower_tri = tf.linalg.LinearOperatorLowerTriangular(lower_tri).to_dense()\n",
    "        masks = tf.tile(tf.expand_dims(lower_tri,0), [tf.shape(align)[0], 1, 1])\n",
    "        align = tf.where(tf.equal(masks, 0), paddings, align)\n",
    "\n",
    "    align = tf.nn.softmax(align)\n",
    "    # tf.cast replaces the deprecated tf.to_float.\n",
    "    query_masks = tf.cast(q_masks, tf.float32)\n",
    "    query_masks = tf.tile(query_masks, [num_heads, 1])\n",
    "    query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, T_k])\n",
    "    align *= query_masks\n",
    "    outputs = tf.matmul(align, V_)\n",
    "    # Unfold the heads back onto the feature axis.\n",
    "    outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)\n",
    "    outputs += queries\n",
    "    outputs = layer_norm(outputs)\n",
    "    return outputs\n",
    "\n",
    "\n",
    "def pointwise_feedforward(inputs, hidden_units, activation=None):\n",
    "    \"\"\"Two dense layers followed by a residual add and `layer_norm`.\n",
    "\n",
    "    The first layer widens to 4 * hidden_units and uses `activation`; the\n",
    "    second projects back to hidden_units with no activation.\n",
    "    \"\"\"\n",
    "    widened = tf.layers.dense(inputs, 4*hidden_units, activation=activation)\n",
    "    projected = tf.layers.dense(widened, hidden_units, activation=None)\n",
    "    return layer_norm(projected + inputs)\n",
    "\n",
    "\n",
    "def learned_position_encoding(inputs, mask, embed_dim):\n",
    "    \"\"\"Embed the position index of every time step and zero it where `mask` is 0.\n",
    "\n",
    "    NOTE(review): this calls `embed_seq`, which is not defined in the visible\n",
    "    notebook, so calling this function as-is would raise NameError. It also\n",
    "    passes `T`, a dynamic `tf.shape` value, as the second argument - confirm\n",
    "    that whatever `embed_seq` is meant to be accepts a tensor there.\n",
    "    \"\"\"\n",
    "    T = tf.shape(inputs)[1]\n",
    "    outputs = tf.range(tf.shape(inputs)[1])                # (T_q)\n",
    "    outputs = tf.expand_dims(outputs, 0)                   # (1, T_q)\n",
    "    outputs = tf.tile(outputs, [tf.shape(inputs)[0], 1])   # (N, T_q)\n",
    "    outputs = embed_seq(outputs, T, embed_dim, zero_pad=False, scale=False)\n",
    "    return tf.expand_dims(tf.to_float(mask), -1) * outputs\n",
    "\n",
    "def sinusoidal_position_encoding(inputs, mask, repr_dim):\n",
    "    \"\"\"Fixed sine / cosine position encoding, zeroed where `mask` is 0.\n",
    "\n",
    "    Args:\n",
    "        inputs: tensor whose first two axes are (batch, time); only its shape is read.\n",
    "        mask: (batch, time) tensor; cast to float32 and multiplied into the result.\n",
    "        repr_dim: encoding width. Sines fill the first half of the channels\n",
    "            and cosines the second half, so an even value is expected.\n",
    "\n",
    "    Returns:\n",
    "        float32 tensor of shape (batch, time, repr_dim).\n",
    "    \"\"\"\n",
    "    T = tf.shape(inputs)[1]\n",
    "    # tf.cast replaces the deprecated tf.to_float.\n",
    "    pos = tf.reshape(tf.range(0.0, tf.cast(T, tf.float32), dtype=tf.float32), [-1, 1])\n",
    "    i = np.arange(0, repr_dim, 2, np.float32)\n",
    "    denom = np.reshape(np.power(10000.0, i / repr_dim), [1, -1])\n",
    "    enc = tf.expand_dims(tf.concat([tf.sin(pos / denom), tf.cos(pos / denom)], 1), 0)\n",
    "    return tf.tile(enc, [tf.shape(inputs)[0], 1, 1]) * tf.expand_dims(tf.cast(mask, tf.float32), -1)\n",
    "\n",
    "def label_smoothing(inputs, epsilon=0.1):\n",
    "    \"\"\"Blend `inputs` with a uniform value over its last axis.\n",
    "\n",
    "    Returns (1 - epsilon) * inputs + epsilon / C, where C is the static size\n",
    "    of the last dimension of `inputs`.\n",
    "    \"\"\"\n",
    "    num_classes = inputs.get_shape().as_list()[-1]\n",
    "    kept = (1 - epsilon) * inputs\n",
    "    return (kept) + (epsilon / num_classes)\n",
    "\n",
    "\n",
    "class Model:\n",
    "    \"\"\"Character-aware self-attention tagger with a CRF output layer.\n",
    "\n",
    "    Constructing an instance adds the placeholders, network, CRF loss, Adam\n",
    "    train op and accuracy / decode ops to the default graph. Vocabulary sizes\n",
    "    are read from the module-level `word2idx`, `char2idx` and `idx2tag`.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 dim_word,\n",
    "                 dim_char,\n",
    "                 dropout,\n",
    "                 learning_rate,\n",
    "                 hidden_size_char,\n",
    "                 hidden_size_word,\n",
    "                 num_blocks = 2,\n",
    "                 num_heads = 8,\n",
    "                 min_freq = 50):\n",
    "        \"\"\"Build the graph.\n",
    "\n",
    "        Args:\n",
    "            dim_word: width of the word embedding table.\n",
    "            dim_char: width of the character embeddings and of every\n",
    "                attention / feed-forward block.\n",
    "            dropout: unused; kept for signature compatibility.\n",
    "            learning_rate: learning rate passed to Adam.\n",
    "            hidden_size_char: unused; kept for signature compatibility.\n",
    "            hidden_size_word: unused; kept for signature compatibility.\n",
    "            num_blocks: number of blocks in the character and word stacks.\n",
    "            num_heads: attention heads per block.\n",
    "            min_freq: unused; kept for signature compatibility.\n",
    "        \"\"\"\n",
    "        self.word_ids = tf.placeholder(tf.int32, shape = [None, None])\n",
    "        self.char_ids = tf.placeholder(tf.int32, shape = [None, None, None])\n",
    "        self.labels = tf.placeholder(tf.int32, shape = [None, None])\n",
    "        self.maxlen = tf.shape(self.word_ids)[1]\n",
    "        # Sequence length = number of non-zero word ids in each row.\n",
    "        self.lengths = tf.count_nonzero(self.word_ids, 1)\n",
    "\n",
    "        self.word_embeddings = tf.Variable(\n",
    "            tf.truncated_normal(\n",
    "                [len(word2idx), dim_word], stddev = 1.0 / np.sqrt(dim_word)\n",
    "            )\n",
    "        )\n",
    "        self.char_embeddings = tf.Variable(\n",
    "            tf.truncated_normal(\n",
    "                [len(char2idx), dim_char], stddev = 1.0 / np.sqrt(dim_char)\n",
    "            )\n",
    "        )\n",
    "\n",
    "        word_embedded = tf.nn.embedding_lookup(\n",
    "            self.word_embeddings, self.word_ids\n",
    "        )\n",
    "        char_embedded = tf.nn.embedding_lookup(\n",
    "            self.char_embeddings, self.char_ids\n",
    "        )\n",
    "        # Merge the first two axes so every word is one sequence of characters.\n",
    "        s = tf.shape(char_embedded)\n",
    "        char_embedded = tf.reshape(\n",
    "            char_embedded, shape = [s[0] * s[1], s[-2], dim_char]\n",
    "        )\n",
    "        reshape_char = tf.reshape(self.char_ids, shape = [s[0] * s[1], s[-2]])\n",
    "        char_masked = tf.sign(reshape_char)\n",
    "        char_embedded += sinusoidal_position_encoding(reshape_char, char_masked, dim_char)\n",
    "        for i in range(num_blocks):\n",
    "            with tf.variable_scope('char_%d'%i,reuse=tf.AUTO_REUSE):\n",
    "                char_embedded = multihead_attn(queries = char_embedded,\n",
    "                                                 keys = char_embedded,\n",
    "                                                 q_masks = char_masked,\n",
    "                                                 k_masks = char_masked,\n",
    "                                                 future_binding = False,\n",
    "                                                 num_units = dim_char,\n",
    "                                                 num_heads = num_heads)\n",
    "            with tf.variable_scope('char_feedforward_%d'%i,reuse=tf.AUTO_REUSE):\n",
    "                char_embedded = pointwise_feedforward(char_embedded,\n",
    "                                                    dim_char,\n",
    "                                                    activation = tf.nn.relu)\n",
    "        # Use the last character slot of each word as its character feature.\n",
    "        # The blocks above keep the dim_char width, so reshape with dim_char\n",
    "        # instead of a separately configured size.\n",
    "        output = tf.reshape(\n",
    "            char_embedded[:, -1], shape = [s[0], s[1], dim_char]\n",
    "        )\n",
    "        word_embedded = tf.concat([word_embedded, output], axis = -1)\n",
    "        word_embedded = tf.layers.dense(word_embedded, dim_char)\n",
    "        de_masks = tf.sign(self.word_ids)\n",
    "        word_embedded += sinusoidal_position_encoding(self.word_ids, de_masks, dim_char)\n",
    "\n",
    "        # Each block consumes the previous block's output; the first one\n",
    "        # starts from the word representation.\n",
    "        decoder_embedded = word_embedded\n",
    "        for i in range(num_blocks):\n",
    "            with tf.variable_scope('word_char_%d'%i,reuse=tf.AUTO_REUSE):\n",
    "                decoder_embedded = multihead_attn(queries = decoder_embedded,\n",
    "                                         keys = decoder_embedded,\n",
    "                                         q_masks = de_masks,\n",
    "                                         k_masks = de_masks,\n",
    "                                         future_binding = True,\n",
    "                                         num_units = dim_char,\n",
    "                                         num_heads = num_heads)\n",
    "\n",
    "            with tf.variable_scope('word_char_attention_%d'%i,reuse=tf.AUTO_REUSE):\n",
    "                decoder_embedded = multihead_attn(queries = decoder_embedded,\n",
    "                                         keys = output,\n",
    "                                         q_masks = de_masks,\n",
    "                                         k_masks = de_masks,\n",
    "                                         future_binding = False,\n",
    "                                         num_units = dim_char,\n",
    "                                         num_heads = num_heads)\n",
    "\n",
    "            with tf.variable_scope('word_feedforward_%d'%i,reuse=tf.AUTO_REUSE):\n",
    "                decoder_embedded = pointwise_feedforward(decoder_embedded,\n",
    "                                                    dim_char,\n",
    "                                            activation = tf.nn.relu)\n",
    "        logits = tf.layers.dense(decoder_embedded, len(idx2tag))\n",
    "        y_t = self.labels\n",
    "        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(\n",
    "            logits, y_t, self.lengths\n",
    "        )\n",
    "        self.cost = tf.reduce_mean(-log_likelihood)\n",
    "        self.optimizer = tf.train.AdamOptimizer(\n",
    "            learning_rate = learning_rate\n",
    "        ).minimize(self.cost)\n",
    "        mask = tf.sequence_mask(self.lengths, maxlen = self.maxlen)\n",
    "        self.tags_seq, tags_score = tf.contrib.crf.crf_decode(\n",
    "            logits, transition_params, self.lengths\n",
    "        )\n",
    "        self.tags_seq = tf.identity(self.tags_seq, name = 'logits')\n",
    "\n",
    "        # Accuracy over the positions selected by `mask` only.\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(self.tags_seq, mask)\n",
    "        mask_label = tf.boolean_mask(y_t, mask)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "WARNING:tensorflow:From <ipython-input-14-48fb560e4b61>:70: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n",
      "WARNING:tensorflow:From <ipython-input-14-48fb560e4b61>:17: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dense instead.\n",
      "\n",
      "WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/contrib/crf/python/ops/crf.py:213: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/rnn.py:626: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gradients_impl.py:110: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
      "  \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
     ]
    }
   ],
   "source": [
    "# Clear the default graph and open the session used by the cells below.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "\n",
    "# Hyper-parameters.\n",
    "dim_word, dim_char = 64, 128\n",
    "dropout = 0.8\n",
    "learning_rate = 1e-3\n",
    "hidden_size_char, hidden_size_word = 64, 64\n",
    "num_layers = 2\n",
    "batch_size = 32\n",
    "\n",
    "model = Model(\n",
    "    dim_word = dim_word,\n",
    "    dim_char = dim_char,\n",
    "    dropout = dropout,\n",
    "    learning_rate = learning_rate,\n",
    "    hidden_size_char = hidden_size_char,\n",
    "    hidden_size_word = hidden_size_word,\n",
    "    num_blocks = num_layers,\n",
    ")\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6362/6362 [22:54<00:00,  5.04it/s, accuracy=0.951, cost=8.84] \n",
      "test minibatch loop: 100%|██████████| 1604/1604 [02:16<00:00, 11.76it/s, accuracy=0.85, cost=17.6] \n",
      "train minibatch loop:   0%|          | 0/6362 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 1510.839835882187\n",
      "epoch: 0, training loss: 23.373860, training acc: 0.854747, valid loss: 23.969787, valid acc: 0.857473\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6362/6362 [22:42<00:00,  5.10it/s, accuracy=0.975, cost=8.44] \n",
      "test minibatch loop: 100%|██████████| 1604/1604 [02:16<00:00, 12.07it/s, accuracy=0.895, cost=11.8]\n",
      "train minibatch loop:   0%|          | 0/6362 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 1498.9244010448456\n",
      "epoch: 1, training loss: 9.763392, training acc: 0.935275, valid loss: 19.811699, valid acc: 0.886606\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6362/6362 [22:41<00:00,  4.98it/s, accuracy=0.977, cost=7.45] \n",
      "test minibatch loop: 100%|██████████| 1604/1604 [02:16<00:00, 12.05it/s, accuracy=0.906, cost=10.7] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 1498.017891407013\n",
      "epoch: 2, training loss: 7.485715, training acc: 0.949510, valid loss: 20.068659, valid acc: 0.889755\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "# Three passes over the training windows, each followed by a pass over the\n",
    "# evaluation windows. Batches are taken in order, without shuffling.\n",
    "for e in range(3):\n",
    "    lasttime = time.time()\n",
    "    train_acc, train_loss, test_acc, test_loss = 0, 0, 0, 0\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        # Slicing clamps at the end of the array, so the last batch may be short.\n",
    "        batch_x = train_X[i : i + batch_size]\n",
    "        batch_char = train_char[i : i + batch_size]\n",
    "        batch_y = train_Y[i : i + batch_size]\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.word_ids: batch_x,\n",
    "                model.char_ids: batch_char,\n",
    "                model.labels: batch_y\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        # Weight by batch size so a short final batch does not skew the epoch mean.\n",
    "        train_loss += cost * len(batch_x)\n",
    "        train_acc += acc * len(batch_x)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    pbar = tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'test minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        batch_x = test_X[i : i + batch_size]\n",
    "        batch_char = test_char[i : i + batch_size]\n",
    "        batch_y = test_Y[i : i + batch_size]\n",
    "        # No optimizer op here: evaluation only.\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.word_ids: batch_x,\n",
    "                model.char_ids: batch_char,\n",
    "                model.labels: batch_y\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        test_loss += cost * len(batch_x)\n",
    "        test_acc += acc * len(batch_x)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    # The sums are weighted by batch size, so divide by the number of examples.\n",
    "    train_loss /= len(train_X)\n",
    "    train_acc /= len(train_X)\n",
    "    test_loss /= len(test_X)\n",
    "    test_acc /= len(test_X)\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (e, train_loss, train_acc, test_loss, test_acc)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pred2label(pred):\n",
    "    \"\"\"Translate a nested sequence of tag ids into tag strings via `idx2tag`.\"\"\"\n",
    "    return [[idx2tag[tag_id] for tag_id in row] for row in pred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation minibatch loop: 100%|██████████| 1604/1604 [02:03<00:00, 12.95it/s]\n"
     ]
    }
   ],
   "source": [
    "real_Y, predict_Y = [], []\n",
    "\n",
    "# Decode every evaluation batch and keep gold and predicted tags as strings.\n",
    "pbar = tqdm(\n",
    "    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'\n",
    ")\n",
    "for i in pbar:\n",
    "    stop = min(i + batch_size, test_X.shape[0])\n",
    "    batch_x, batch_char, batch_y = test_X[i : stop], test_char[i : stop], test_Y[i : stop]\n",
    "    feed = {\n",
    "        model.word_ids: batch_x,\n",
    "        model.char_ids: batch_char,\n",
    "    }\n",
    "    predicted = pred2label(sess.run(model.tags_seq, feed_dict = feed))\n",
    "    real = pred2label(batch_y)\n",
    "    predict_Y.extend(predicted)\n",
    "    real_Y.extend(real)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "             precision    recall  f1-score   support\n",
      "\n",
      "          \"       1.00      1.00      1.00     32000\n",
      "          $       1.00      1.00      1.00      5050\n",
      "         ''       0.86      0.47      0.61       550\n",
      "          (       1.00      1.00      1.00     33950\n",
      "          )       1.00      1.00      1.00     34000\n",
      "          ,       1.00      1.00      1.00     97450\n",
      "          .       1.00      1.00      1.00     93842\n",
      "          :       1.00      1.00      1.00     31105\n",
      "         CC       1.00      1.00      1.00     46551\n",
      "         CD       0.91      0.99      0.95    214444\n",
      "         DT       1.00      0.99      0.99    175916\n",
      "         EX       0.96      0.67      0.79      1998\n",
      "         FW       0.24      0.67      0.35      1450\n",
      "         IN       0.97      0.98      0.97    248430\n",
      "         JJ       0.75      0.82      0.78    152092\n",
      "        JJR       0.53      0.79      0.63      5250\n",
      "        JJS       0.25      0.84      0.38      3900\n",
      "         LS       0.00      0.00      0.00        50\n",
      "         MD       0.94      0.99      0.96     15000\n",
      "         NN       0.86      0.80      0.83    300941\n",
      "        NNP       0.82      0.86      0.84    427200\n",
      "       NNPS       0.60      0.28      0.38      8300\n",
      "        NNS       0.86      0.75      0.80    125063\n",
      "     NN|SYM       0.56      0.94      0.70        50\n",
      "        PDT       0.56      0.81      0.66       350\n",
      "        POS       0.97      0.99      0.98     21150\n",
      "        PRP       0.99      0.98      0.99     43100\n",
      "       PRP$       0.97      1.00      0.98     21086\n",
      "         RB       0.69      0.86      0.77     49446\n",
      "        RBR       0.63      0.45      0.52      2650\n",
      "        RBS       0.86      0.59      0.70       900\n",
      "         RP       0.73      0.51      0.60      7500\n",
      "        SYM       0.92      0.76      0.83      4300\n",
      "         TO       1.00      1.00      1.00     45288\n",
      "         UH       0.00      0.00      0.00       250\n",
      "         VB       0.81      0.86      0.84     55934\n",
      "        VBD       0.91      0.89      0.90    111414\n",
      "        VBG       0.91      0.33      0.49     35000\n",
      "        VBN       0.84      0.69      0.75     49650\n",
      "        VBP       0.88      0.73      0.79     18250\n",
      "        VBZ       0.87      0.76      0.81     25550\n",
      "        WDT       0.96      0.80      0.87      7750\n",
      "         WP       0.90      0.98      0.94      6350\n",
      "        WP$       1.00      1.00      1.00       450\n",
      "        WRB       0.91      0.97      0.94      4650\n",
      "\n",
      "avg / total       0.89      0.89      0.89   2565600\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jupyter/.local/lib/python3.6/site-packages/sklearn/metrics/classification.py:1135: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n",
      "  'precision', 'predicted', average, warn_for)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "\n",
    "# Flatten the per-window tag lists into one token-level vector each.\n",
    "y_true = np.array(real_Y).ravel()\n",
    "y_pred = np.array(predict_Y).ravel()\n",
    "print(classification_report(y_true, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
